#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" By Martin Senande-Rivera
    For Spatial and temporal expansion of global wildland fire activity in response to climate change """

import copy
import datetime as dt
import glob, sys
import os

import numpy as np
import xarray as xr
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from mpl_toolkits.basemap import Basemap, maskoceans
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import cmaps
import cmocean
from pandas import Series, DataFrame
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import FormatStrFormatter
from matplotlib.path import Path
import matplotlib.patches as patches


# Global font size applied to every text element in the figure.
font = {'size': 5}
matplotlib.rc('font', **font)

# Gridded fields for the three map panels, produced by the threshold-selection
# step (BA_max, PP_FS and TA_FS judging by the panel titles further down).
ruta_data = '../Classification/1-Threshold_selection/'

BA = xr.open_dataset(ruta_data + 'BA_max.nc')['BurnedArea']
Pp = xr.open_dataset(ruta_data + 'P_perce.nc')['__xarray_dataarray_variable__']
Ta = xr.open_dataset(ruta_data + 'T_anom.nc')['__xarray_dataarray_variable__']

ruta_gfed4 = '../DATA/GFED4/'  # GFED4 data path
ds_ba = xr.open_dataset(ruta_gfed4 + 'GFED4_BA_months.nc')
start_date = dt.datetime.strptime('1996-01', '%Y-%m')
end_date = dt.datetime.strptime('2017-01', '%Y-%m')
# Month-end stamps 1996-01-31 ... 2016-12-31 (252 months): the x axis of the
# regional inset time series. An explicit MonthEnd offset behaves the same on
# every pandas version, whereas the 'm'/'M' alias is deprecated from pandas 2.2
# in favour of 'ME' (which older pandas does not accept).
dates = pd.date_range(start_date, end_date, freq=pd.offsets.MonthEnd())
# NOTE(review): assumes GFED4_BA_months.nc holds exactly these 252 months, in
# this order; a different length makes the inset plot() calls raise.
BA_serie = ds_ba.BurnedArea

# Global lat/lon (EPSG:4326) map shared by all three panels.
m = Basemap(epsg=4326, resolution='i',
            llcrnrlat=-90., urcrnrlat=90.,
            llcrnrlon=-180., urcrnrlon=180.)
lats = BA['lat']
lons = BA['lon']
# 2-D lon/lat grids for the contourf calls.
x, y = np.meshgrid(lons.values, lats.values)

def _box_vertices(lon_west, lon_east, lat_south, lat_north):
    """Return the closed outline of a lon/lat box as (lon, lat) tuples.

    Order is SW, NW, NE, SE, then SW again; that last vertex only pairs with
    Path.CLOSEPOLY and is ignored when the path is drawn.
    """
    south_west = (lon_west, lat_south)
    return [
        south_west,
        (lon_west, lat_north),
        (lon_east, lat_north),
        (lon_east, lat_south),
        south_west,
    ]


verts1 = _box_vertices(-5, 33, 4, 12)        # Africa
verts2 = _box_vertices(118, 144, -32, -22)   # Australia
verts3 = _box_vertices(-53, -42, -19, -5)    # Brazil
verts4 = _box_vertices(-127, -99, 52, 65)    # Canada
verts5 = _box_vertices(-10, 4, 35, 44)       # Iberian Peninsula
verts6 = _box_vertices(108, 138, 47, 65)     # Russia
verts7 = _box_vertices(-125, -115, 32, 45)   # USA

# One path code per vertex: move to the first corner, draw three edges, close.
codes2 = [Path.MOVETO] + [Path.LINETO] * 3 + [Path.CLOSEPOLY]

# BOXES: latitude/longitude bounds (degrees) of each study region; the same
# limits are used for the outline vertices above and the .sel() slices below.
# Africa = lat=slice(4.,12.),lon=slice(-5.,33.)
# Australia = lat=slice(-32.,-22.),lon=slice(118.,144.)
# Brazil = lat=slice(-19.,-5.),lon=slice(-53.,-42.)
# Canada = lat=slice(52.,65.),lon=slice(-127.,-99.)
# IP (Iberian Peninsula) = lat=slice(35.,44.),lon=slice(-10.,4.)
# Russia = lat=slice(47.,65.),lon=slice(108.,138.)
# USA = lat=slice(32.,45.),lon=slice(-125.,-115.)

fig = plt.figure(figsize=(6, 4))
gs = fig.add_gridspec(3, 2)
# Panel a spans the top two rows and both columns of the 3x2 grid.
ax1 = fig.add_subplot(gs[0:-1, :])
# No ax= is passed, so Basemap draws on its default axes: the just-added ax1.
m.drawcoastlines(linewidth=0.25, zorder=3)
m.drawcountries(linewidth=0.1, zorder=3)
m.drawlsmask(land_color='none', ocean_color="#e4e4e4", lakes=True, zorder=2)

# Outline every study region with a translucent grey box drawn above the
# coastline/mask layers (zorder 4 > 3).
for region_verts in (verts1, verts2, verts3, verts4, verts5, verts6, verts7):
    ax1.add_patch(patches.PathPatch(Path(region_verts, codes2),
                                    fc=(0, 0, 0, 0.2), ec=(0, 0, 0, 1),
                                    lw=0.3, zorder=4))


# Maps ax1 axes-fraction coordinates (0..1) into ax1 data (lon/lat)
# coordinates; computed once, before any connector line is drawn.
axis_to_data = ax1.transAxes + ax1.transData.inverted()


def _add_ba_inset(parent_ax, basemap, to_data, series, time_index,
                  lon_bounds, lat_bounds, origin, width, corners, height=0.1):
    """Add one regional burned-area time-series inset to *parent_ax*.

    Parameters
    ----------
    parent_ax : Axes that hosts the map and receives the inset.
    basemap : Basemap used to draw the connector lines.
    to_data : transform from *parent_ax* axes fraction to data coordinates.
    series : gridded burned area with ``lat`` and ``lon`` dimensions.
    time_index : x values of the inset; must have the length of the dimension
        left after averaging over lat/lon (presumably time).
    lon_bounds, lat_bounds : (west, east) and (south, north) box limits.
    origin : lower-left corner of the inset, in *parent_ax* axes fraction.
    width, height : inset size, in *parent_ax* axes fraction.
    corners : box corners among 'll', 'ul', 'ur', 'lr'; each one is joined by
        a thin black line to the same corner of the inset.

    Returns
    -------
    The inset Axes.
    """
    lon_w, lon_e = lon_bounds
    lat_s, lat_n = lat_bounds
    x0, y0 = origin

    # Spatial mean over the box, reduced by dimension *name* so the result does
    # not depend on the order of the dimensions in the file.
    box_mean = series.sel(lat=slice(lat_s, lat_n),
                          lon=slice(lon_w, lon_e)).mean(dim=['lat', 'lon'])

    # Box corner on the map (lon, lat) -> matching inset corner (axes fraction).
    anchors = {
        'll': ((lon_w, lat_s), (x0, y0)),
        'ul': ((lon_w, lat_n), (x0, y0 + height)),
        'ur': ((lon_e, lat_n), (x0 + width, y0 + height)),
        'lr': ((lon_e, lat_s), (x0 + width, y0)),
    }
    for corner in corners:
        (lon_c, lat_c), inset_corner = anchors[corner]
        end_x, end_y = to_data.transform(inset_corner)
        basemap.plot([lon_c, end_x], [lat_c, end_y], 'k',
                     linewidth=0.3, zorder=4)

    inset = parent_ax.inset_axes([x0, y0, width, height])
    inset.plot(time_index, box_mean, 'k', linewidth=0.4)
    inset.xaxis.set_minor_locator(matplotlib.dates.YearLocator(1))
    inset.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%Y'))
    inset.xaxis.set_major_locator(matplotlib.dates.YearLocator(3))
    inset.yaxis.set_major_formatter(FormatStrFormatter('%d $ha$'))
    inset.margins(x=0, y=0)
    inset.tick_params(axis='both', which='both', length=1, pad=0.6,
                      labelsize=1.5, width=0.3)
    inset.grid(True, which='both', linestyle='-', linewidth=0.3)
    for side in ('top', 'bottom', 'left', 'right'):
        inset.spines[side].set_linewidth(0.3)
    return inset


# One row per region, drawn in this order:
# name, (lon W, lon E), (lat S, lat N), inset origin in ax1 axes fraction,
# inset width, box corners connected to the same inset corner.
_INSET_REGIONS = (
    ('Africa',            (-5., 33.),     (4., 12.),    (0.66, 0.4),  0.11, ('ll', 'ur')),
    ('Australia',         (118., 144.),   (-32., -22.), (0.7, 0.2),   0.14, ('ul', 'lr')),
    ('Brazil',            (-53., -42.),   (-19., -5.),  (0.4, 0.2),   0.14, ('ll', 'ur')),
    ('Canada',            (-127., -99.),  (52., 65.),   (0.03, 0.54), 0.14, ('ul',)),
    ('Iberian Peninsula', (-10., 4.),     (35., 44.),   (0.32, 0.62), 0.12, ('ul', 'lr')),
    ('Russia',            (108., 138.),   (47., 65.),   (0.87, 0.55), 0.12, ('ll', 'ur')),
    ('USA',               (-125., -115.), (32., 45.),   (0.1, 0.3),   0.14, ('ur',)),
)

for _region_name, _lons, _lats, _origin, _width, _corners in _INSET_REGIONS:
    _add_ba_inset(ax1, m, axis_to_data, BA_serie, dates,
                  _lons, _lats, _origin, _width, _corners)

def _draw_map_background():
    """Draw coastlines, borders and the grey ocean/lake mask.

    No ax= is passed, so Basemap draws on its default (current) axes; call this
    right after adding the panel's subplot.
    """
    m.drawcoastlines(linewidth=0.25, zorder=3)
    m.drawcountries(linewidth=0.1, zorder=3)
    m.drawlsmask(land_color='none', ocean_color="#e4e4e4", lakes=True, zorder=2)


def _add_side_colorbar(parent_ax, mappable, ticks):
    """Attach a slim vertical colorbar just outside the right edge of *parent_ax*."""
    cax = inset_axes(parent_ax,
                     width="2%",
                     height="85%",
                     loc='lower left',
                     bbox_to_anchor=(1.01, 0., 1, 1),
                     bbox_transform=parent_ax.transAxes,
                     borderpad=0,
                     )
    return fig.colorbar(mappable, ticks=ticks, cax=cax)


def _label_panel(ax, letter, title, title_size):
    """Write the bold panel letter (top-left) and the variable title (top-right, outside)."""
    ax.text(0.01, 0.99, letter, family='sans-serif', weight='bold', size=7,
            horizontalalignment='left', verticalalignment='top',
            transform=ax.transAxes)
    ax.text(1.01, 0.99, title, family='sans-serif', weight='bold', size=title_size,
            horizontalalignment='left', verticalalignment='top',
            transform=ax.transAxes)


# ---- Panel a: maximum burned area (drawn on ax1, the current axes) ----
clevs = np.array([0, 100, 500, 1000, 2000, 5000, 10000, 20000])
cmap = cmocean.cm.thermal
norm = mcolors.BoundaryNorm(clevs, cmap.N, extend='max')
cs2 = m.contourf(x, y, BA, clevs, cmap=cmap, norm=norm, alpha=1., extend='max', zorder=1)
cbar = _add_side_colorbar(ax1, cs2, clevs)
_label_panel(ax1, 'a', 'BA$_{max} (ha)$', 7)

# ---- Panel b: precipitation field (bottom-left) ----
ax2 = fig.add_subplot(gs[2, 0])
_draw_map_background()
clevs = np.array([0, 2, 4, 6, 8, 10, 12, 14])
# Work on a copy so setting the over-colour does not mutate the shared
# cmocean colormap object for the rest of the process.
cmap = copy.copy(cmocean.cm.deep)
cmap.set_over("#1b2c62")
norm = mcolors.BoundaryNorm(clevs, cmap.N, extend='max')
cs2 = m.contourf(x, y, Pp, clevs, cmap=cmap, norm=norm, alpha=1., extend='max', zorder=1)
cbar = _add_side_colorbar(ax2, cs2, clevs)
# Title now belongs to ax2 (it was added to ax1 with ax2's transform before);
# raw string keeps the same mathtext without an invalid '\%' escape.
_label_panel(ax2, 'b', r'PP$_{FS} (\%)$', 5)

# ---- Panel c: temperature anomaly (bottom-right) ----
ax3 = fig.add_subplot(gs[2, 1])
_draw_map_background()
clevs = np.array([-15, -10, -5, 0, 5, 10, 15])
cmap = cmocean.cm.balance
norm = mcolors.BoundaryNorm(clevs, cmap.N, extend='both')
cs2 = m.contourf(x, y, Ta, clevs, cmap=cmap, norm=norm, alpha=1., extend='both', zorder=1)
cbar = _add_side_colorbar(ax3, cs2, clevs)
_label_panel(ax3, 'c', r'TA$_{FS} (^\circ C)$', 5)

plt.subplots_adjust(top=0.95, bottom=0.05, hspace=0.1, wspace=0.25)

fig.savefig('Figure1.png', dpi=300, bbox_inches='tight')
#fig.savefig('Figure1.pdf', bbox_inches='tight')
